{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LoRA 微调：Whisper-Large-V2 中文语音识别\n",
    "\n",
    "\n",
    "\n",
    "1. 在“LoRA 低秩适配 OpenAI Whisper-Large-V2 语音识别任务”中，为中文语料的训练过程增加过程评估，观察 Train Loss 和 Validation Loss 变化。课程代码（ https://github.com/DjangoPeng/LLM-quickstart/blob/main/peft/peft_lora_whisper-large-v2.ipynb ）\n",
    "\n",
    "2. 在“LoRA 低秩适配 OpenAI Whisper-Large-V2 语音识别任务”中，当 LoRA 模型训练完成后，使用测试集进行完整的模型评估。课程代码（ https://github.com/DjangoPeng/LLM-quickstart/blob/main/peft/peft_lora_whisper-large-v2.ipynb ）\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide configuration: the cells below read these names.\n",
    "\n",
    "# Base checkpoint location and the directory passed to the trainer as output_dir.\n",
    "model_name_or_path = \"/root/autodl-tmp/models/whisper-large-v2\"\n",
    "model_dir = \"/root/autodl-tmp/output/models/whisper-large-v2-asr-int8\"\n",
    "\n",
    "# Dataset id and the configuration name handed to load_dataset.\n",
    "dataset_name = \"mozilla-foundation/common_voice_11_0\"\n",
    "language_abbr = \"zh-CN\"\n",
    "\n",
    "# Language and task handed to the tokenizer and the processor.\n",
    "# NOTE(review): confirm the tokenizer honours this display name; Whisper's own\n",
    "# language names look like 'chinese' / 'zh'.\n",
    "language = \"Chinese (China)\"\n",
    "task = \"transcribe\"\n",
    "\n",
    "# Used for both the per-device training and evaluation batch size.\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "\n",
    "# Source /etc/network_turbo in a bash subshell, list the resulting environment\n",
    "# variables whose line contains 'proxy', and copy them into this kernel's os.environ.\n",
    "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n",
    "output = result.stdout\n",
    "\n",
    "for entry in output.splitlines():\n",
    "    # Split on the first '=' only, so values that contain '=' stay intact.\n",
    "    name, separator, setting = entry.partition('=')\n",
    "    if separator:\n",
    "        os.environ[name] = setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'client_id': '95368aab163e0387e4fd4991b4f2d8ccfbd4364bf656c860230501fd27dcedf087773e4695a6cf5de9c4f1d406d582283190d065cdfa36b0e2b060cffaca977e',\n",
       " 'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
       " 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
       "  'array': array([-6.82121026e-13, -2.27373675e-12, -2.27373675e-12, ...,\n",
       "          1.21667399e-05,  3.23003678e-06, -2.43066324e-07]),\n",
       "  'sampling_rate': 48000},\n",
       " 'sentence': '性喜温暖润湿气候且耐寒。',\n",
       " 'up_votes': 2,\n",
       " 'down_votes': 0,\n",
       " 'age': '',\n",
       " 'gender': '',\n",
       " 'accent': '',\n",
       " 'locale': 'zh-CN',\n",
       " 'segment': ''}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset, DatasetDict\n",
    "\n",
    "# Load the train and validation splits of the configured dataset into one DatasetDict.\n",
    "# trust_remote_code=True is forwarded to load_dataset unchanged.\n",
    "common_voice = DatasetDict(\n",
    "    {\n",
    "        split_name: load_dataset(dataset_name, language_abbr, split=split_name, trust_remote_code=True)\n",
    "        for split_name in (\"train\", \"validation\")\n",
    "    }\n",
    ")\n",
    "\n",
    "# Show one raw training example.\n",
    "common_voice[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],\n",
       "        num_rows: 29056\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],\n",
       "        num_rows: 10581\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "common_voice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoFeatureExtractor, AutoTokenizer, AutoProcessor\n",
    "\n",
    "# Language/task settings shared by the tokenizer and the processor.\n",
    "whisper_text_kwargs = {\"language\": language, \"task\": task}\n",
    "\n",
    "# Feature extractor: applied to the raw waveform in prepare_dataset.\n",
    "feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)\n",
    "\n",
    "# Tokenizer: turns transcripts into label token ids.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **whisper_text_kwargs)\n",
    "\n",
    "# Processor: exposes a feature extractor and a tokenizer; the data collator uses both.\n",
    "processor = AutoProcessor.from_pretrained(model_name_or_path, **whisper_text_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the metadata columns that training does not need.\n",
    "# Only columns that are still present are removed, so running this cell again is a\n",
    "# no-op instead of raising because the columns are already gone.\n",
    "unused_columns = [\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"path\", \"segment\", \"up_votes\"]\n",
    "\n",
    "pruned_splits = {}\n",
    "for split_name, split in common_voice.items():\n",
    "    present = [column for column in unused_columns if column in split.column_names]\n",
    "    # Skip the call entirely when nothing is left to drop.\n",
    "    pruned_splits[split_name] = split.remove_columns(present) if present else split\n",
    "\n",
    "common_voice = DatasetDict(pruned_splits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意：`remove_columns` 只能删除仍然存在的列；这些列被删除后若再次对它们调用，会因找不到列而报错。如果遇到该报错，可以直接忽略，不影响后续流程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
       "  'array': array([ 5.09317033e-11, -7.27595761e-12, -6.54836185e-11, ...,\n",
       "         -5.96661994e-06,  2.71382887e-05,  1.29687978e-05]),\n",
       "  'sampling_rate': 16000},\n",
       " 'sentence': '性喜温暖润湿气候且耐寒。'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import Audio\n",
    "\n",
    "# Re-declare the audio column at 16 kHz (the raw example shown earlier reports 48 kHz).\n",
    "TARGET_SAMPLING_RATE = 16000\n",
    "common_voice = common_voice.cast_column(\"audio\", Audio(sampling_rate=TARGET_SAMPLING_RATE))\n",
    "\n",
    "# The first training example now reports the new sampling rate.\n",
    "common_voice[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(batch):\n",
    "    \"\"\"Turn one dataset example into model inputs.\n",
    "\n",
    "    Reads ``batch[\"audio\"]`` (waveform array plus sampling rate) and\n",
    "    ``batch[\"sentence\"]``, then stores two new keys on the same mapping:\n",
    "\n",
    "    - ``input_features``: first entry of the feature extractor output for the waveform\n",
    "    - ``labels``: token ids of the transcript\n",
    "\n",
    "    Uses the notebook-level ``feature_extractor`` and ``tokenizer``.\n",
    "\n",
    "    Returns:\n",
    "        The same ``batch`` object with the two keys added.\n",
    "    \"\"\"\n",
    "    clip = batch[\"audio\"]\n",
    "    extracted = feature_extractor(clip[\"array\"], sampling_rate=clip[\"sampling_rate\"])\n",
    "    batch[\"input_features\"] = extracted.input_features[0]\n",
    "\n",
    "    encoded = tokenizer(batch[\"sentence\"])\n",
    "    batch[\"labels\"] = encoded.input_ids\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a small subset: shuffle each split with a fixed seed and keep the\n",
    "# first N examples.\n",
    "subset_sizes = {\"train\": 640, \"validation\": 320}\n",
    "\n",
    "small_common_voice = DatasetDict(\n",
    "    {\n",
    "        split_name: common_voice[split_name].shuffle(seed=16).select(range(size))\n",
    "        for split_name, size in subset_sizes.items()\n",
    "    }\n",
    ")\n",
    "\n",
    "# Run prepare_dataset over every example of the subset.\n",
    "tokenized_common_voice = small_common_voice.map(prepare_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: preprocess the FULL dataset instead of the small subset.\n",
    "# This is opt-in because the assignment replaces `tokenized_common_voice`; left\n",
    "# unconditional, a Restart & Run All would silently train on every example.\n",
    "USE_FULL_DATASET = False\n",
    "\n",
    "if USE_FULL_DATASET:\n",
    "    tokenized_common_voice = common_voice.map(prepare_dataset, num_proc=8)\n",
    "\n",
    "tokenized_common_voice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Union\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DataCollatorSpeechSeq2SeqWithPadding:\n",
    "    \"\"\"Collate preprocessed speech examples into one padded batch.\n",
    "\n",
    "    Attributes:\n",
    "        processor: object exposing ``feature_extractor`` and ``tokenizer``. Both must\n",
    "            provide ``pad(...)``; the tokenizer must also provide ``bos_token_id``.\n",
    "    \"\"\"\n",
    "\n",
    "    processor: Any\n",
    "\n",
    "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"Pad audio features and label ids separately, then merge them.\n",
    "\n",
    "        Args:\n",
    "            features: examples carrying ``input_features`` and ``labels``.\n",
    "\n",
    "        Returns:\n",
    "            The padded feature batch with an added ``labels`` tensor in which padded\n",
    "            positions are set to -100.\n",
    "        \"\"\"\n",
    "        # The audio side and the text side are padded independently.\n",
    "        audio_items = []\n",
    "        text_items = []\n",
    "        for example in features:\n",
    "            audio_items.append({\"input_features\": example[\"input_features\"]})\n",
    "            text_items.append({\"input_ids\": example[\"labels\"]})\n",
    "\n",
    "        batch = self.processor.feature_extractor.pad(audio_items, return_tensors=\"pt\")\n",
    "        padded_text = self.processor.tokenizer.pad(text_items, return_tensors=\"pt\")\n",
    "\n",
    "        # Positions whose attention mask is not 1 are padding; overwrite them with -100.\n",
    "        is_padding = padded_text.attention_mask.ne(1)\n",
    "        labels = padded_text[\"input_ids\"].masked_fill(is_padding, -100)\n",
    "\n",
    "        # When every row starts with the BOS token, drop that leading column.\n",
    "        starts_with_bos = labels[:, 0] == self.processor.tokenizer.bos_token_id\n",
    "        if starts_with_bos.all().cpu().item():\n",
    "            labels = labels[:, 1:]\n",
    "\n",
    "        batch[\"labels\"] = labels\n",
    "        return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the collator around the processor loaded earlier.\n",
    "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/other.py:143: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Model preparation\n",
    "from transformers import AutoModelForSpeechSeq2Seq\n",
    "from peft import prepare_model_for_kbit_training\n",
    "\n",
    "# Load the base checkpoint with load_in_8bit=True and device_map=\"auto\".\n",
    "model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path, load_in_8bit=True, device_map=\"auto\")\n",
    "\n",
    "# Clear forced_decoder_ids so no decoder token ids are forced by the config.\n",
    "model.config.forced_decoder_ids = None\n",
    "\n",
    "# Empty suppress_tokens so no tokens are suppressed by the config.\n",
    "model.config.suppress_tokens = []\n",
    "\n",
    "# peft emits a FutureWarning for prepare_model_for_int8_training and names\n",
    "# prepare_model_for_kbit_training as its replacement, so call the replacement directly.\n",
    "model = prepare_model_for_kbit_training(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WhisperForConditionalGeneration(\n",
       "  (model): WhisperModel(\n",
       "    (encoder): WhisperEncoder(\n",
       "      (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))\n",
       "      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))\n",
       "      (embed_positions): Embedding(1500, 1280)\n",
       "      (layers): ModuleList(\n",
       "        (0-31): 32 x WhisperEncoderLayer(\n",
       "          (self_attn): WhisperSdpaAttention(\n",
       "            (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "            (v_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (q_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "          )\n",
       "          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "          (activation_fn): GELUActivation()\n",
       "          (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
       "          (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (decoder): WhisperDecoder(\n",
       "      (embed_tokens): Embedding(51865, 1280, padding_idx=50257)\n",
       "      (embed_positions): WhisperPositionalEmbedding(448, 1280)\n",
       "      (layers): ModuleList(\n",
       "        (0-31): 32 x WhisperDecoderLayer(\n",
       "          (self_attn): WhisperSdpaAttention(\n",
       "            (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "            (v_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (q_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "          )\n",
       "          (activation_fn): GELUActivation()\n",
       "          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "          (encoder_attn): WhisperSdpaAttention(\n",
       "            (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "            (v_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (q_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "          )\n",
       "          (encoder_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
       "          (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "  )\n",
       "  (proj_out): Linear(in_features=1280, out_features=51865, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 3,932,160 || all params: 1,547,237,120 || trainable%: 0.25414074863974306\n"
     ]
    }
   ],
   "source": [
    "from peft import LoraConfig, LoraModel, PeftModel, get_peft_model\n",
    "\n",
    "# LoRA hyper-parameters, collected in one place before building the config.\n",
    "lora_settings = {\n",
    "    \"r\": 8,  # rank of the low-rank update matrices\n",
    "    \"lora_alpha\": 64,  # LoRA scaling factor\n",
    "    \"target_modules\": [\"q_proj\", \"v_proj\"],  # module names that receive adapters\n",
    "    \"lora_dropout\": 0.05,  # dropout rate used inside the LoRA layers\n",
    "    \"bias\": \"none\",  # bias handling mode\n",
    "}\n",
    "config = LoraConfig(**lora_settings)\n",
    "\n",
    "# Wrap the prepared model with the adapters and report the trainable parameter count.\n",
    "peft_model = get_peft_model(model, config)\n",
    "peft_model.print_trainable_parameters()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the module from /root/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--wer/85bee9e4216a78bb09b2d0d500f6af5c23da58f9210e661add540f5df6630fcd (last modified on Sat Feb 17 11:14:38 2024) since it couldn't be found locally at evaluate-metric--wer, or remotely on the Hugging Face Hub.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"wer\")\n",
    "\n",
    "\n",
    "def compute_metrics(pred):\n",
    "    \"\"\"Compute the word error rate (in percent) for one evaluation pass.\n",
    "\n",
    "    Args:\n",
    "        pred: object exposing ``predictions`` and ``label_ids``. ``predictions`` may be\n",
    "            token ids (2-D), logits (3-D), or a tuple whose first item is one of those.\n",
    "\n",
    "    Returns:\n",
    "        dict with the single key ``\"wer\"``.\n",
    "    \"\"\"\n",
    "    pad_id = tokenizer.pad_token_id\n",
    "\n",
    "    # Replace -100 with the pad token id on a copy, so the caller's array is not\n",
    "    # modified in place.\n",
    "    raw_labels = np.asarray(pred.label_ids)\n",
    "    label_ids = np.where(raw_labels == -100, pad_id, raw_labels)\n",
    "\n",
    "    pred_ids = pred.predictions\n",
    "    # Predictions are not guaranteed to be a plain 2-D id array. Handing a tuple or a\n",
    "    # 3-D array straight to batch_decode raises\n",
    "    # \"argument 'ids': 'list' object cannot be interpreted as an integer\".\n",
    "    if isinstance(pred_ids, tuple):\n",
    "        # Assumes the first item holds the logits / token ids - TODO confirm.\n",
    "        pred_ids = pred_ids[0]\n",
    "    pred_ids = np.asarray(pred_ids)\n",
    "    if pred_ids.ndim == 3:\n",
    "        # Assumed layout (batch, sequence, vocabulary): pick the best token per position.\n",
    "        pred_ids = pred_ids.argmax(axis=-1)\n",
    "        if pred_ids.shape == raw_labels.shape:\n",
    "            # Positions without a label carry no meaningful prediction; blank them out.\n",
    "            pred_ids = np.where(raw_labels == -100, pad_id, pred_ids)\n",
    "    # Predictions may contain -100 padding as well, which cannot be decoded.\n",
    "    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)\n",
    "\n",
    "    # we do not want to group tokens when computing the metrics\n",
    "    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
    "    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
    "\n",
    "    # NOTE(review): WER compares whitespace-separated words; for unsegmented Chinese\n",
    "    # text a character-level metric may be more informative - confirm which is wanted.\n",
    "    wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n",
    "\n",
    "    return {\"wer\": wer}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
    "\n",
    "# Training hyper-parameters.\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    output_dir=model_dir,  # directory the trainer writes to\n",
    "    # --- optimisation ---\n",
    "    learning_rate=1e-3,\n",
    "    num_train_epochs=3,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    # --- evaluation ---\n",
    "    evaluation_strategy=\"epoch\",  # evaluate at the end of every epoch\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    generation_max_length=128,  # maximum length used for generation\n",
    "    # --- logging ---\n",
    "    logging_steps=10,\n",
    "    # --- dataset column handling ---\n",
    "    remove_unused_columns=False,  # pass every dataset column through to the collator\n",
    "    label_names=[\"labels\"],  # name of the label field in each batch\n",
    "    # Left disabled: warmup_steps=50, fp16=True, and step-based evaluation\n",
    "    # (evaluation_strategy=\"steps\", eval_steps=25).\n",
    ")\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model=peft_model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_common_voice[\"train\"],\n",
    "    eval_dataset=tokenized_common_voice[\"validation\"],\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=processor.feature_extractor,\n",
    ")\n",
    "\n",
    "# Disable the cache flag on the model config.\n",
    "peft_model.config.use_cache = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sat Feb 17 11:52:34 2024       \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |\n",
      "|-----------------------------------------+----------------------+----------------------+\n",
      "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                                         |                      |               MIG M. |\n",
      "|=========================================+======================+======================|\n",
      "|   0  NVIDIA GeForce RTX 3090        On  | 00000000:C3:00.0 Off |                  N/A |\n",
      "| 32%   20C    P8              21W / 350W |  14890MiB / 24576MiB |      0%      Default |\n",
      "|                                         |                      |                  N/A |\n",
      "+-----------------------------------------+----------------------+----------------------+\n",
      "                                                                                         \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| Processes:                                                                            |\n",
      "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\n",
      "|        ID   ID                                                             Usage      |\n",
      "|=======================================================================================|\n",
      "+---------------------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:61: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='11' max='30' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [11/30 02:35 < 05:27, 0.06 it/s, Epoch 1/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>\n",
       "    <div>\n",
       "      \n",
       "      <progress value='5' max='5' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [5/5 00:52]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "TypeError",
     "evalue": "argument 'ids': 'list' object cannot be interpreted as an integer",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[18], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:1539\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   1537\u001b[0m         hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m   1538\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1539\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1540\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1541\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1542\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1543\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1544\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:1944\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   1941\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_training_stop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m   1943\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_epoch_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 1944\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1946\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m DebugOption\u001b[38;5;241m.\u001b[39mTPU_METRICS_DEBUG \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdebug:\n\u001b[1;32m   1947\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m is_torch_tpu_available():\n\u001b[1;32m   1948\u001b[0m         \u001b[38;5;66;03m# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2291\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2289\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   2290\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_evaluate:\n\u001b[0;32m-> 2291\u001b[0m     metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2292\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_report_to_hp_search(trial, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step, metrics)\n\u001b[1;32m   2294\u001b[0m     \u001b[38;5;66;03m# Run delayed LR scheduler now that metrics are populated\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:166\u001b[0m, in \u001b[0;36mSeq2SeqTrainer.evaluate\u001b[0;34m(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)\u001b[0m\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather_function \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mgather\n\u001b[1;32m    165\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gen_kwargs \u001b[38;5;241m=\u001b[39m gen_kwargs\n\u001b[0;32m--> 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43meval_dataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:3095\u001b[0m, in \u001b[0;36mTrainer.evaluate\u001b[0;34m(self, eval_dataset, ignore_keys, metric_key_prefix)\u001b[0m\n\u001b[1;32m   3092\u001b[0m start_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m   3094\u001b[0m eval_loop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprediction_loop \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39muse_legacy_prediction_loop \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mevaluation_loop\n\u001b[0;32m-> 3095\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43meval_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3096\u001b[0m \u001b[43m    \u001b[49m\u001b[43meval_dataloader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3097\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdescription\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mEvaluation\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3098\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;66;43;03m# No point gathering the predictions if there are no metrics, otherwise we defer to\u001b[39;49;00m\n\u001b[1;32m   3099\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;66;43;03m# self.args.prediction_loss_only\u001b[39;49;00m\n\u001b[1;32m   3100\u001b[0m \u001b[43m    \u001b[49m\u001b[43mprediction_loss_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_metrics\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m 
\u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   3101\u001b[0m \u001b[43m    \u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3102\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3103\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3105\u001b[0m total_batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39meval_batch_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mworld_size\n\u001b[1;32m   3106\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmetric_key_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_jit_compilation_time\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m output\u001b[38;5;241m.\u001b[39mmetrics:\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:3386\u001b[0m, in \u001b[0;36mTrainer.evaluation_loop\u001b[0;34m(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)\u001b[0m\n\u001b[1;32m   3382\u001b[0m         metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_metrics(\n\u001b[1;32m   3383\u001b[0m             EvalPrediction(predictions\u001b[38;5;241m=\u001b[39mall_preds, label_ids\u001b[38;5;241m=\u001b[39mall_labels, inputs\u001b[38;5;241m=\u001b[39mall_inputs)\n\u001b[1;32m   3384\u001b[0m         )\n\u001b[1;32m   3385\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3386\u001b[0m         metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_metrics\u001b[49m\u001b[43m(\u001b[49m\u001b[43mEvalPrediction\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredictions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mall_preds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mall_labels\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3387\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   3388\u001b[0m     metrics \u001b[38;5;241m=\u001b[39m {}\n",
      "Cell \u001b[0;32mIn[15], line 13\u001b[0m, in \u001b[0;36mcompute_metrics\u001b[0;34m(pred)\u001b[0m\n\u001b[1;32m     10\u001b[0m label_ids[label_ids \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m100\u001b[39m] \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mpad_token_id\n\u001b[1;32m     12\u001b[0m \u001b[38;5;66;03m# we do not want to group tokens when computing the metrics\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m pred_str \u001b[38;5;241m=\u001b[39m \u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m     14\u001b[0m label_str \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(label_ids, skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28mprint\u001b[39m(pred_str)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:3716\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.batch_decode\u001b[0;34m(self, sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)\u001b[0m\n\u001b[1;32m   3692\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbatch_decode\u001b[39m(\n\u001b[1;32m   3693\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m   3694\u001b[0m     sequences: Union[List[\u001b[38;5;28mint\u001b[39m], List[List[\u001b[38;5;28mint\u001b[39m]], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnp.ndarray\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch.Tensor\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtf.Tensor\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   3697\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m   3698\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[\u001b[38;5;28mstr\u001b[39m]:\n\u001b[1;32m   3699\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m   3700\u001b[0m \u001b[38;5;124;03m    Convert a list of lists of token ids into a list of strings by calling decode.\u001b[39;00m\n\u001b[1;32m   3701\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   3714\u001b[0m \u001b[38;5;124;03m        `List[str]`: The list of decoded sentences.\u001b[39;00m\n\u001b[1;32m   3715\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m-> 3716\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m   3717\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecode(\n\u001b[1;32m   3718\u001b[0m             seq,\n\u001b[1;32m   3719\u001b[0m             skip_special_tokens\u001b[38;5;241m=\u001b[39mskip_special_tokens,\n\u001b[1;32m   3720\u001b[0m             
clean_up_tokenization_spaces\u001b[38;5;241m=\u001b[39mclean_up_tokenization_spaces,\n\u001b[1;32m   3721\u001b[0m             \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m   3722\u001b[0m         )\n\u001b[1;32m   3723\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m seq \u001b[38;5;129;01min\u001b[39;00m sequences\n\u001b[1;32m   3724\u001b[0m     ]\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:3717\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m   3692\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbatch_decode\u001b[39m(\n\u001b[1;32m   3693\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m   3694\u001b[0m     sequences: Union[List[\u001b[38;5;28mint\u001b[39m], List[List[\u001b[38;5;28mint\u001b[39m]], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnp.ndarray\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch.Tensor\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtf.Tensor\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   3697\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m   3698\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[\u001b[38;5;28mstr\u001b[39m]:\n\u001b[1;32m   3699\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m   3700\u001b[0m \u001b[38;5;124;03m    Convert a list of lists of token ids into a list of strings by calling decode.\u001b[39;00m\n\u001b[1;32m   3701\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   3714\u001b[0m \u001b[38;5;124;03m        `List[str]`: The list of decoded sentences.\u001b[39;00m\n\u001b[1;32m   3715\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m   3716\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[0;32m-> 3717\u001b[0m         \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3718\u001b[0m \u001b[43m            \u001b[49m\u001b[43mseq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3719\u001b[0m \u001b[43m            
\u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3720\u001b[0m \u001b[43m            \u001b[49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3721\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3722\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3723\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m seq \u001b[38;5;129;01min\u001b[39;00m sequences\n\u001b[1;32m   3724\u001b[0m     ]\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/models/whisper/tokenization_whisper_fast.py:391\u001b[0m, in \u001b[0;36mWhisperTokenizerFast.decode\u001b[0;34m(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, output_offsets, time_precision, decode_with_timestamps, normalize, basic_normalize, remove_diacritics, **kwargs)\u001b[0m\n\u001b[1;32m    351\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    352\u001b[0m \u001b[38;5;124;03mConverts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special\u001b[39;00m\n\u001b[1;32m    353\u001b[0m \u001b[38;5;124;03mtokens and clean up tokenization spaces.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    384\u001b[0m \u001b[38;5;124;03m    `str`: The decoded sentence.\u001b[39;00m\n\u001b[1;32m    385\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    386\u001b[0m filtered_ids \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_preprocess_token_ids(\n\u001b[1;32m    387\u001b[0m     token_ids,\n\u001b[1;32m    388\u001b[0m     skip_special_tokens\u001b[38;5;241m=\u001b[39mskip_special_tokens,\n\u001b[1;32m    389\u001b[0m )\n\u001b[0;32m--> 391\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    392\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfiltered_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    393\u001b[0m \u001b[43m    \u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    394\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    395\u001b[0m \u001b[43m    \u001b[49m\u001b[43mnormalize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnormalize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    396\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbasic_normalize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbasic_normalize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    397\u001b[0m \u001b[43m    \u001b[49m\u001b[43mremove_diacritics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mremove_diacritics\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    398\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    399\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    400\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m decode_with_timestamps:\n\u001b[1;32m    401\u001b[0m     \u001b[38;5;66;03m# legacy method to decode timestamps when not included in the tokenizer vocabulary\u001b[39;00m\n\u001b[1;32m    402\u001b[0m     text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_decode_with_timestamps(\n\u001b[1;32m    403\u001b[0m         filtered_ids, time_precision\u001b[38;5;241m=\u001b[39mtime_precision, skip_special_tokens\u001b[38;5;241m=\u001b[39mskip_special_tokens\n\u001b[1;32m    404\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:3756\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.decode\u001b[0;34m(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)\u001b[0m\n\u001b[1;32m   3753\u001b[0m \u001b[38;5;66;03m# Convert inputs to python lists\u001b[39;00m\n\u001b[1;32m   3754\u001b[0m token_ids \u001b[38;5;241m=\u001b[39m to_py_obj(token_ids)\n\u001b[0;32m-> 3756\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_decode\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3757\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtoken_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3758\u001b[0m \u001b[43m    \u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3759\u001b[0m \u001b[43m    \u001b[49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3760\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3761\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/models/whisper/tokenization_whisper_fast.py:417\u001b[0m, in \u001b[0;36mWhisperTokenizerFast._decode\u001b[0;34m(self, normalize, basic_normalize, remove_diacritics, *args, **kwargs)\u001b[0m\n\u001b[1;32m    414\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_decode\u001b[39m(\n\u001b[1;32m    415\u001b[0m     \u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, normalize: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m, basic_normalize: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m, remove_diacritics: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    416\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[0;32m--> 417\u001b[0m     text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    419\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m normalize:\n\u001b[1;32m    420\u001b[0m         clean_text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_normalize(text)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/tokenization_utils_fast.py:625\u001b[0m, in \u001b[0;36mPreTrainedTokenizerFast._decode\u001b[0;34m(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)\u001b[0m\n\u001b[1;32m    623\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(token_ids, \u001b[38;5;28mint\u001b[39m):\n\u001b[1;32m    624\u001b[0m     token_ids \u001b[38;5;241m=\u001b[39m [token_ids]\n\u001b[0;32m--> 625\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_tokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtoken_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_special_tokens\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    627\u001b[0m clean_up_tokenization_spaces \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m    628\u001b[0m     clean_up_tokenization_spaces\n\u001b[1;32m    629\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m clean_up_tokenization_spaces \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    630\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclean_up_tokenization_spaces\n\u001b[1;32m    631\u001b[0m )\n\u001b[1;32m    632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m clean_up_tokenization_spaces:\n",
      "\u001b[0;31mTypeError\u001b[0m: argument 'ids': 'list' object cannot be interpreted as an integer"
     ]
    }
   ],
   "source": [
    "trainer.train()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
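    "**问题：** 上面的 `trainer.train()` 为什么会报错？如何把评估过程加入到训练中？\n",
    "\n",
    "**分析：** 从报错栈看，错误发生在 `compute_metrics` 里的 `tokenizer.batch_decode(pred_ids, ...)`：传入的每个元素仍是嵌套列表，而不是 token id 序列（`'list' object cannot be interpreted as an integer`），很可能是因为评估时得到的 `predictions` 并不是生成的 token id。下面的做法是不再传入 `compute_metrics`，只通过 `evaluation_strategy=\"epoch\"` 在每个 epoch 结束时记录 Validation Loss。"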
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
    "\n",
    "# Hyper-parameters for the LoRA fine-tuning run. Evaluation runs once per epoch so\n",
    "# that training loss and validation loss can be compared side by side.\n",
    "# Options left disabled for this run: warmup_steps=50, fp16=True, and\n",
    "# step-based evaluation (evaluation_strategy=\"steps\", eval_steps=25).\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    output_dir=model_dir,  # directory for checkpoints and the saved model\n",
    "    num_train_epochs=3,\n",
    "    learning_rate=1e-3,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    evaluation_strategy=\"epoch\",  # evaluate at the end of every epoch\n",
    "    logging_steps=10,  # logging interval, in steps\n",
    "    generation_max_length=128,  # maximum length used for generation\n",
    "    remove_unused_columns=False,  # keep all dataset columns (presumably required by the custom data collator)\n",
    "    label_names=[\"labels\"],  # name of the label field in each batch\n",
    ")\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model=peft_model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_common_voice[\"train\"],\n",
    "    eval_dataset=tokenized_common_voice[\"validation\"],\n",
    "    tokenizer=processor.feature_extractor,\n",
    ")\n",
    "\n",
    "# Disable `use_cache` on the model config for training.\n",
    "peft_model.config.use_cache = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='30' max='30' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [30/30 12:14, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.500700</td>\n",
       "      <td>0.481170</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.243100</td>\n",
       "      <td>0.486947</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.167200</td>\n",
       "      <td>0.483523</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=30, training_loss=0.30368061463038126, metrics={'train_runtime': 758.71, 'train_samples_per_second': 2.531, 'train_steps_per_second': 0.04, 'total_flos': 4.087359995904e+18, 'train_loss': 0.30368061463038126, 'epoch': 3.0})"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the fine-tuning loop. With per-epoch evaluation configured above, the progress\n",
    "# table in the output reports Training Loss and Validation Loss for each epoch.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/whisper-large-v2 - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Save the trained model to `model_dir`, then write the trainer state.\n",
    "trainer.save_model(model_dir)\n",
    "trainer.save_state()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='2' max='2' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [2/2 00:04]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.49766218662261963,\n",
       " 'eval_runtime': 23.2109,\n",
       " 'eval_samples_per_second': 4.308,\n",
       " 'eval_steps_per_second': 0.086,\n",
       " 'epoch': 3.0}"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate on a fixed random subset of 100 examples (shuffle seed 64).\n",
    "# NOTE(review): despite its name, `small_test_dataset` is sampled from the\n",
    "# \"validation\" split (the same split passed as eval_dataset to the trainer),\n",
    "# not from a held-out test split.\n",
    "small_test_dataset = tokenized_common_voice[\"validation\"].shuffle(seed=64).select(range(100))\n",
    "trainer.evaluate(small_test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.模型推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Directory from which the saved adapter config and weights are loaded below.\n",
    "model_dir = \"/root/autodl-tmp/output/models/whisper-large-v2-asr-int8\"\n",
    "\n",
    "# Language / task settings passed to the tokenizer and processor in the next cell.\n",
    "language = \"Chinese (China)\"\n",
    "language_abbr = \"zh-CN\"  # dataset configuration name passed to load_dataset\n",
    "language_decode = \"chinese\"  # NOTE(review): not referenced by the cells that follow; confirm it is still needed\n",
    "task = \"transcribe\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftConfig, PeftModel\n",
    "from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer\n",
    "\n",
    "# The adapter config stored in `model_dir` names the base checkpoint to load.\n",
    "peft_config = PeftConfig.from_pretrained(model_dir)\n",
    "\n",
    "# Load that base checkpoint with 8-bit loading and automatic device placement.\n",
    "base_model = AutoModelForSpeechSeq2Seq.from_pretrained(\n",
    "    peft_config.base_model_name_or_path,\n",
    "    load_in_8bit=True,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "\n",
    "# Wrap the base model with the adapter saved in `model_dir`.\n",
    "peft_model = PeftModel.from_pretrained(base_model, model_dir)\n",
    "\n",
    "# Processor and tokenizer come from the same base checkpoint and are configured\n",
    "# with the target language and task.\n",
    "processor = AutoProcessor.from_pretrained(\n",
    "    peft_config.base_model_name_or_path, language=language, task=task\n",
    ")\n",
    "feature_extractor = processor.feature_extractor\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    peft_config.base_model_name_or_path, language=language, task=task\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 评估数据集处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the module from /root/.cache/huggingface/modules/datasets_modules/datasets/mozilla-foundation--common_voice_11_0/3f27acf10f303eac5b6fbbbe02495aeddb46ecffdb0a2fe3507fcfbf89094631 (last modified on Sat Feb 17 00:37:59 2024) since it couldn't be found locally at mozilla-foundation/common_voice_11_0, or remotely on the Hugging Face Hub.\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset, DatasetDict, Audio\n",
    "\n",
    "# Load only the test split, drop the metadata columns that are not used for\n",
    "# evaluation, and cast the audio column to a 16 kHz sampling rate.\n",
    "common_voice = (\n",
    "    DatasetDict(\n",
    "        {\"test\": load_dataset(dataset_name, language_abbr, split=\"test\", trust_remote_code=True)}\n",
    "    )\n",
    "    .remove_columns(\n",
    "        [\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"path\", \"segment\", \"up_votes\"]\n",
    "    )\n",
    "    .cast_column(\"audio\", Audio(sampling_rate=16000))\n",
    ")\n",
    "\n",
    "def prepare_dataset(batch):\n",
    "    \"\"\"Add model inputs and label ids to one example and return it.\n",
    "\n",
    "    Reads the `audio` entry (its `array` and `sampling_rate`) and the `sentence`\n",
    "    entry, stores the first `input_features` item produced by `feature_extractor`\n",
    "    and the `input_ids` produced by `tokenizer`, then returns the same mapping.\n",
    "    \"\"\"\n",
    "    clip = batch[\"audio\"]\n",
    "    extracted = feature_extractor(clip[\"array\"], sampling_rate=clip[\"sampling_rate\"])\n",
    "    batch[\"input_features\"] = extracted.input_features[0]\n",
    "\n",
    "    encoded_text = tokenizer(batch[\"sentence\"])\n",
    "    batch[\"labels\"] = encoded_text.input_ids\n",
    "    return batch\n",
    "\n",
    "# Evaluate on a reproducible 320-example sample of the test split (shuffle seed 16),\n",
    "# not on the full split.\n",
    "small_common_voice = DatasetDict(\n",
    "    {\"test\": common_voice[\"test\"].shuffle(seed=16).select(range(320))}\n",
    ")\n",
    "\n",
    "# Apply `prepare_dataset` to every example of the sampled subset.\n",
    "tokenized_common_voice = small_common_voice.map(prepare_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Union\n",
    "\n",
    "# 定义一个针对语音到文本任务的数据整理器类\n",
    "@dataclass\n",
    "class DataCollatorSpeechSeq2SeqWithPadding:\n",
    "    \"\"\"Collate speech-to-text examples into a single padded batch.\n",
    "\n",
    "    Attributes:\n",
    "        processor: Object exposing `feature_extractor` and `tokenizer`; the `pad`\n",
    "            method of each is used, plus the tokenizer's `bos_token_id`.\n",
    "    \"\"\"\n",
    "\n",
    "    processor: Any\n",
    "\n",
    "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"Pad audio features and label ids separately, then merge them.\n",
    "\n",
    "        Args:\n",
    "            features: Examples carrying `input_features` and `labels` entries.\n",
    "\n",
    "        Returns:\n",
    "            The padded feature batch with an added `labels` tensor in which\n",
    "            padded positions are replaced by -100.\n",
    "        \"\"\"\n",
    "        audio_inputs = []\n",
    "        label_inputs = []\n",
    "        for example in features:\n",
    "            audio_inputs.append({\"input_features\": example[\"input_features\"]})\n",
    "            label_inputs.append({\"input_ids\": example[\"labels\"]})\n",
    "\n",
    "        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors=\"pt\")\n",
    "\n",
    "        label_tokenizer = self.processor.tokenizer\n",
    "        padded_labels = label_tokenizer.pad(label_inputs, return_tensors=\"pt\")\n",
    "\n",
    "        # Positions whose attention mask is not 1 are padding; mark them with -100\n",
    "        # (the value conventionally ignored when computing the loss).\n",
    "        is_padding = padded_labels.attention_mask.ne(1)\n",
    "        labels = padded_labels[\"input_ids\"].masked_fill(is_padding, -100)\n",
    "\n",
    "        # When every row starts with the BOS id, drop that leading column.\n",
    "        starts_with_bos = labels[:, 0] == label_tokenizer.bos_token_id\n",
    "        if starts_with_bos.all().cpu().item():\n",
    "            labels = labels[:, 1:]\n",
    "\n",
    "        batch[\"labels\"] = labels\n",
    "        return batch\n",
    "\n",
    "# Instantiate the data collator with the processor loaded earlier in this notebook.\n",
    "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 评估模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the module from /root/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--wer/85bee9e4216a78bb09b2d0d500f6af5c23da58f9210e661add540f5df6630fcd (last modified on Sat Feb 17 11:14:38 2024) since it couldn't be found locally at evaluate-metric--wer, or remotely on the Hugging Face Hub.\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]/root/miniconda3/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
      "100%|██████████| 20/20 [03:14<00:00,  9.74s/it]\n"
     ]
    }
   ],
   "source": [
    "import gc\n",
    "\n",
    "import evaluate\n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Word error rate (WER) metric loaded through the `evaluate` library.\n",
    "# NOTE(review): the reported 65.9375% equals 211/320, which suggests each utterance\n",
    "# is being scored as a single word; for unsegmented Chinese text a character-level\n",
    "# metric (CER) may be more informative. Confirm before comparing numbers.\n",
    "metric = evaluate.load(\"wer\")\n",
    "batch_size = 16\n",
    "\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_common_voice[\"test\"],\n",
    "    batch_size=batch_size,\n",
    "    collate_fn=data_collator,\n",
    ")\n",
    "\n",
    "for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "    # Run under CUDA autocast and without gradient tracking (inference only).\n",
    "    with torch.cuda.amp.autocast(), torch.no_grad():\n",
    "        generated = peft_model.generate(\n",
    "            input_features=batch[\"input_features\"].to(\"cuda\"),\n",
    "            # The first four label tokens are passed as the initial decoder input.\n",
    "            decoder_input_ids=batch[\"labels\"][:, :4].to(\"cuda\"),\n",
    "            max_new_tokens=255,\n",
    "        )\n",
    "        # Move the generated ids to CPU as a NumPy array for decoding.\n",
    "        generated_tokens = generated.cpu().numpy()\n",
    "\n",
    "        # Replace the -100 placeholders in the labels with the pad token id so that\n",
    "        # they can be decoded.\n",
    "        labels = batch[\"labels\"].cpu().numpy()\n",
    "        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "\n",
    "        # Decode predictions and references to text, skipping special tokens.\n",
    "        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
    "        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "\n",
    "        # Accumulate this batch into the metric.\n",
    "        metric.add_batch(predictions=decoded_preds, references=decoded_labels)\n",
    "\n",
    "    # Drop per-batch objects and trigger garbage collection to free memory.\n",
    "    del generated, generated_tokens, labels, batch\n",
    "    gc.collect()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "wer=65.9375%\n"
     ]
    }
   ],
   "source": [
    "# Compute the word error rate accumulated by `metric` and scale it to a percentage.\n",
    "wer = 100 * metric.compute()\n",
    "\n",
    "# The f-string `=` specifier prints both the variable name and its value.\n",
    "print(f\"{wer=}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "原始数据指标：wer=71.5625%（作为对比，上文 LoRA 微调后的评估结果为 wer=65.9375%）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
